Synphony

Deep Learning Final Project - MSDS Spring Module 2 - 2025

Aditi Puttur & Emma Juan

1. Data Preprocessing¶

In [9]:
import pandas as pd
import numpy as np

import os
import json

from tqdm import tqdm

import re
import unicodedata

import warnings
warnings.filterwarnings("ignore")

from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile
from symusic import Score

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import math
from typing import Optional

import traceback

Loading the data¶

LMD: Midi Files¶

In [ ]:
# Open and read the JSON file
# Maps each MIDI file's MD5 hash to its original path(s) in the source
# collection (mapping file shipped with the Lakh MIDI Dataset).
with open('data/LMD/md5_to_paths.json', 'r') as file:
    md5_to_paths = json.load(file)
In [ ]:
md5_to_paths
In [ ]:
# Recursively collect every .mid file under the LMD-matched tree.
lmd_catalog = []

for root, _subdirs, names in os.walk('data/LMD/lmd_matched'):
    for name in names:
        if name.endswith('.mid'):
            lmd_catalog.append(os.path.join(root, name))
In [ ]:
lmd_catalog.sort()
lmd_catalog
In [ ]:
len(lmd_catalog)
In [ ]:
# Catalogue each MIDI file with the MSD track ID (its parent folder name)
# and the LMD file name (the file stem without extension).
lmd_catalog_all = {
    'path': lmd_catalog,
    'MSD_name': [p.split('/')[-2] for p in lmd_catalog],
    'LMD_name': [p.split('/')[-1].split('.')[-2] for p in lmd_catalog],
}

lmd_df = pd.DataFrame(lmd_catalog_all)
lmd_df
In [ ]:
lmd_df["MSD_name"].nunique()

LMD-matched metadata (MillionSongDataset): The Metadata¶

In [ ]:
import hdf5_getters
In [ ]:
# Walk the LMD-matched-MSD tree, collecting each .h5 metadata file's path
# plus the title / artist / release / year fields read via hdf5_getters.
msd_catalog = []
titles = []
artists = []
releases = []
years = []

for dirpath, dirnames, filenames in tqdm(os.walk('data/LMD-matched-MSD')):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.h5'):

            # Append the path to the list
            msd_catalog.append(full_path)

            # Get the metadata. FIX: close each h5 handle — the original
            # never closed them, leaking an OS file handle per track over
            # the whole walk (tens of thousands of files).
            h5 = hdf5_getters.open_h5_file_read(full_path)
            try:
                titles.append(hdf5_getters.get_title(h5))
                artists.append(hdf5_getters.get_artist_name(h5))
                releases.append(hdf5_getters.get_release(h5))
                years.append(hdf5_getters.get_year(h5))
                # Other available fields: get_danceability, get_energy, ...
            finally:
                h5.close()
In [ ]:
msd_catalog
In [ ]:
len(msd_catalog)
In [ ]:
len(msd_catalog) == lmd_df["MSD_name"].nunique()
In [ ]:
titles[:5]
In [ ]:
artists[:5]
In [ ]:
years[:5]
In [ ]:
titles = [title.decode('utf-8') for title in titles]
artists = [artist.decode('utf-8') for artist in artists]
In [ ]:
# Assemble the MSD metadata table; MSD_name is the .h5 file stem so it can
# be joined against the LMD catalogue. Key order fixes the column order.
msd_catalog_all = {
    'path': msd_catalog,
    'MSD_name': [p.split('/')[-1].split('.')[-2] for p in msd_catalog],
    'title': titles,
    'artist': artists,
    'year': years,
}

msd_df = pd.DataFrame(msd_catalog_all)
msd_df
In [ ]:
msd_df.info()

tagtraum: Adding Genre Tags¶

In [ ]:
# Parse the tagtraum CD2C genre annotations: one "<track_id>\t<genre>" row
# per line; lines starting with '#' are comments.
tagtraum = {'MSD_name': [],
            'genre': []}

with open("data/tagtraum/msd_tagtraum_cd2c.cls", "r") as file:
    # Iterate the file lazily instead of materialising every line with
    # readlines(). Skip blank lines (the original crashed on them with
    # "not enough values to unpack") and split with maxsplit=1 so a stray
    # extra tab cannot raise "too many values to unpack".
    for line in file:
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        track, genre = line.split('\t', 1)
        tagtraum['MSD_name'].append(track)
        tagtraum['genre'].append(genre)
In [ ]:
tagtraum_df = pd.DataFrame(tagtraum)
tagtraum_df
In [ ]:
tagtraum_df["genre"].unique()

Creating our dataset: MIDI + Metadata + Genres¶

Midi + Metadata¶

Each track (MSD_name -> track_id) has one metadata file, and different MIDI files (LMD_name -> midi_id) associated with it.

In [ ]:
len(lmd_df), len(msd_df)
In [ ]:
lmd_df["MSD_name"].nunique(), len(msd_df)
In [ ]:
dataset = lmd_df.merge(msd_df, how="inner", on="MSD_name", suffixes=('_lmd', '_msd'))
dataset = dataset.rename(columns={"path_lmd": "midi_filepath",
                                  "path_msd": "metadata_filepath",
                                  "MSD_name": "track_id",
                                  "LMD_name": "midi_id"})
dataset = dataset[["track_id", "midi_id", "midi_filepath",
                   "title", "artist", "year"]]
dataset
In [ ]:
grouped_dataset = dataset.groupby('track_id').first().reset_index()
grouped_dataset = grouped_dataset[['track_id', 'midi_id', 'midi_filepath']]
grouped_dataset = grouped_dataset.merge(
    dataset[
        ['track_id', "title", "artist", "year"]
    ].drop_duplicates(), on='track_id', how='left' )
grouped_dataset = grouped_dataset[["track_id", "midi_id", "midi_filepath",
                                   "title", "artist", "year"]]
grouped_dataset

Adding the genre tags¶

In [ ]:
dataset = dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
dataset = dataset.drop(columns=["MSD_name"])
dataset
In [ ]:
grouped_dataset = grouped_dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
grouped_dataset = grouped_dataset.drop(columns=["MSD_name"])
grouped_dataset

Sluggifying our parameters¶

In [ ]:
genres = dataset["genre"].unique()
artists = dataset["artist"].unique()
years = dataset["year"].unique()
In [ ]:
def slug(text: str) -> str:
    """Normalise `text` to an ALL_CAPS identifier of [A-Z0-9_] characters.

    Accents are stripped via NFKD decomposition to plain ASCII, every run
    of non-word characters becomes a single underscore, and leading /
    trailing underscores are removed before upper-casing.
    """
    ascii_text = (
        unicodedata.normalize("NFKD", text)
        .encode("ascii", "ignore")
        .decode()
    )
    underscored = re.sub(r"[^\w]+", "_", ascii_text)
    collapsed = re.sub(r"_+", "_", underscored)
    return collapsed.strip("_").upper()
In [ ]:
# Slug every genre/artist label; drop NaN years and cast the rest to int.
genres_slugged = np.array([slug(g) for g in genres])
artists_slugged = np.array([slug(a) for a in artists])
years = np.array([int(y) for y in years if not pd.isna(y)])
In [ ]:
# Wrap each conditioning parameter in a small lookup DataFrame
# (original label alongside its slugged form).
genres = pd.DataFrame({'genre': genres,
                       'slugged_genre': genres_slugged})

artists = pd.DataFrame({'artist': artists,
                        'slugged_artist': artists_slugged})

years = pd.DataFrame({'year': years})
In [ ]:
genres = genres.sort_values(by='genre')
artists = artists.sort_values(by='artist')
years = years.sort_values(by='year')
In [ ]:
# Attach the slugged labels to both datasets. Build each label→slug Series
# once instead of recomputing set_index for every .map call.
genre_to_slug = genres.set_index('genre')['slugged_genre']
artist_to_slug = artists.set_index('artist')['slugged_artist']

dataset["slugged_genre"] = dataset["genre"].map(genre_to_slug)
dataset["slugged_artist"] = dataset["artist"].map(artist_to_slug)

grouped_dataset["slugged_genre"] = grouped_dataset["genre"].map(genre_to_slug)
grouped_dataset["slugged_artist"] = grouped_dataset["artist"].map(artist_to_slug)

Saving our data¶

Saving the metadata datasets¶

In [ ]:
dataset.to_csv("data/metadata.csv", index=False)
In [ ]:
grouped_dataset.to_csv("data/grouped_metadata.csv", index=False)

Saving the different parameters to csvs¶

In [ ]:
genres.to_csv("data/genres.csv", index=False)
artists.to_csv("data/artists.csv", index=False)
years.to_csv("data/years.csv", index=False)

2. Model Implementation¶

In [85]:
# Reload the preprocessed metadata and parameter tables so Section 2 can
# run without re-executing Section 1.
dataset = pd.read_csv("data/metadata.csv")
grouped_dataset = pd.read_csv("data/grouped_metadata.csv")

genres = pd.read_csv("data/genres.csv")
artists = pd.read_csv("data/artists.csv")
years = pd.read_csv("data/years.csv")
# FIX: the original also read "data/titles.csv", but Section 1 never writes
# that file (only genres/artists/years are saved above) and `titles` is not
# used below — the read is dropped so Restart-&-Run-All works.
In [86]:
# Pull the raw value arrays used below to build the special-token vocab.
genres_slugged = genres["slugged_genre"].values
artists_slugged = artists["slugged_artist"].values
years_vals = years["year"].values
In [87]:
# Config with which the model was first trained:
# MAX_TOKENS = 512
# BATCH_SIZE = 2

# D_MODEL    = 512
# N_LAYERS   = 6
# N_HEADS    = 8

# New (larger) config to try
MAX_TOKENS = 1024   # target sequence length (sequences are cropped/padded to this)
BATCH_SIZE = 8      # sequences per training batch

D_MODEL = 768       # embedding / hidden size
N_LAYERS = 8        # number of decoder blocks
N_HEADS = 12 # 768 / 12 = 64 per head

Tokenization¶

Defining the tokenizer¶

In [88]:
# REMI tokenizer configuration (miditok).
config = TokenizerConfig(
    num_velocities=32,              # quantise note velocities into 32 bins
    use_chords=True,                # emit Chord tokens
    use_programs=True,              # keep per-track Program (instrument) tokens
    beat_res={(0,4): 8, (4,8): 4},  # finer time grid for the first 4 beats
    use_rests=True,                 # explicit Rest tokens instead of silent gaps
    rest_range=(2,8),               # NOTE(review): newer miditok versions use
                                    # `beat_res_rest` instead — confirm version
    use_time_signatures=True
)

tokenizer = REMI(config)

Adding our special tokens¶

In [89]:
# Register one conditioning token per genre / artist / year, plus the
# end-of-sequence and padding tokens, in the REMI vocabulary.
special_toks = (
    [f"<GENRE_{g}>" for g in genres_slugged]
    + [f"<ARTIST_{a}>" for a in artists_slugged]
    + [f"<YEAR_{y}>" for y in years_vals]
    + ["<EOS>", "<PAD>"]
)

for tok in special_toks:
    tokenizer.add_to_vocab(tok)

Tokenizing: Storing each track as a numpy int32 array.¶

In [90]:
tokenizing = False
In [91]:
# ─── 1. Helpers ──────────────────────────────────────────────────────────
def build_prefix(genre, artist, year, tokenizer):
    """Map a (genre, artist, year) metadata triple to its three
    conditioning-token IDs, looked up in ``tokenizer.vocab``.

    Raises KeyError if any token was not registered in the vocabulary.
    """
    tokens = (f"<GENRE_{genre}>", f"<ARTIST_{artist}>", f"<YEAR_{year}>")
    return [tokenizer.vocab[tok] for tok in tokens]

# ─── 3. Output directory -------------------------------------------------
out_dir = "data/tokens/train"

# ─── 4. Iterate files ----------------------------------------------------
# For each track: build the 3-token conditioning prefix, tokenize the MIDI,
# append <EOS>, and save the whole sequence as an int32 .npy file.
if tokenizing:
    n_rows, _ = grouped_dataset.shape
    for i in tqdm(range(n_rows)):
        # FIX: defined up front so the except-block can always report it
        # (the original raised NameError in the handler if the row lookup
        # itself failed before midi_path was assigned).
        midi_path = None
        try:
            # 4.0. Get row. FIX: the original reused the loop-counter name
            #      `row` for the Series, shadowing the index.
            row = grouped_dataset.iloc[i]

            # 4.1. MIDI filepath and track ID
            midi_path = row["midi_filepath"]
            track_id = row["track_id"]

            # 4a. Build CONDITIONING prefix (genre, artist, year token IDs)
            prefix_ids = build_prefix(row["slugged_genre"],
                                      row["slugged_artist"],
                                      row["year"],
                                      tokenizer)

            # 4b. Encode MIDI to tokens
            midi = Score(midi_path)
            midi_tokens = tokenizer(midi)

            # 4c. Concatenate prefix + midi + <EOS>
            seq_ids = prefix_ids + midi_tokens.ids + [tokenizer.vocab["<EOS>"]]

            # 4d. Save as int32 .npy keyed by track ID
            np.save(f"{out_dir}/{track_id}.npy",
                    np.array(seq_ids, dtype=np.int32))
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            traceback.print_exc()
            continue

The Model¶

In [92]:
class RelativePositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    NOTE: despite the class name, this is the standard *absolute*
    sinusoidal table (Vaswani et al.). It is returned with shape
    (1, seq_len, d_model) so it broadcasts over the batch when added to
    an embedded input:  ``x + pos(x)``.

    Parameters
    ----------
    d_model : int
        Embedding size.
    max_len : int, optional
        Longest sequence length the table supports.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Build the (max_len, d_model) sinusoid table once up front.
        pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (L, 1)
        freq = torch.exp(
            -(math.log(10000.0) / d_model)
            * torch.arange(0, d_model, 2, dtype=torch.float)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(pos * freq)  # even dims: sine
        table[:, 1::2] = torch.cos(pos * freq)  # odd dims: cosine

        # Buffer, not Parameter: moves with .to(device), never trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the positional table sliced to `x`'s sequence length.

        Parameters
        ----------
        x : Tensor of shape (batch, seq_len, d_model)

        Returns
        -------
        Tensor of shape (1, seq_len, d_model), broadcastable over batch.

        Raises
        ------
        ValueError
            If seq_len exceeds `max_len`.
        """
        length = x.size(1)
        if length > self.max_len:
            raise ValueError(f"Sequence length {length} exceeds max_len {self.max_len}")
        return self.pe[:length].unsqueeze(0)
In [93]:
class TransformerDecoderBlock(nn.Module):
    """Single decoder block with merged causal + padding masking.

    Both masks are combined into ONE additive float mask of shape
    (B*H, L, L) before reaching ``nn.MultiheadAttention``, so no implicit
    bool→float conversions (and their memory blow-ups) can occur.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Submodules created in this exact order (state-dict compatible).
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout = nn.Dropout(dropout)

        # Additive causal mask: 0 on/below the diagonal, -inf above it.
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,              # (B, L, D)
        pad_mask: torch.Tensor = None  # (B, L); True=real token, False=pad
    ) -> torch.Tensor:
        """Apply masked self-attention and feed-forward, both with
        post-norm residual connections. Returns a (B, L, D) tensor."""
        batch, length, _ = x.shape
        n_heads = self.self_attn.num_heads

        # Slice the precomputed causal mask to this sequence length.
        causal = self.causal_mask[:length, :length]            # (L, L)

        if pad_mask is None:
            # A 2-D mask broadcasts over batch and heads automatically.
            attn_mask = causal
        else:
            # 0 where the KEY is a real token, -inf where it is padding.
            key_bias = torch.zeros(
                (batch, length), device=x.device, dtype=x.dtype
            ).masked_fill(~pad_mask, float("-inf"))            # (B, L)
            # (1, L, L) + (B, 1, L) → (B, L, L), then one copy per head.
            merged = causal.unsqueeze(0) + key_bias.unsqueeze(1)
            attn_mask = merged.repeat_interleave(n_heads, dim=0)  # (B*H, L, L)

        attn_out, _ = self.self_attn(x, x, x, attn_mask=attn_mask)

        # Residual + norm around attention, then around the feed-forward.
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x
In [94]:
class Synphony(nn.Module):
    """Decoder-only Transformer language model over the REMI token vocab.

    Pipeline: token embedding + sinusoidal positions → `n_layers` masked
    self-attention blocks → LayerNorm → linear projection to vocab logits.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = RelativePositionalEncoding(d_model, max_len=2048)
        self.blocks = nn.ModuleList(
            [TransformerDecoderBlock(d_model, n_heads) for _ in range(n_layers)]
        )
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Return next-token logits of shape (B, L, vocab_size).

        Parameters
        ----------
        x : LongTensor (B, L) of token IDs
        pad_mask : BoolTensor (B, L); True = real token, False = padding
        """
        hidden = self.embed(x) + self.pos(x)
        for block in self.blocks:
            hidden = block(hidden, pad_mask)
        return self.out(self.ln(hidden))

The Training Loop¶

In [95]:
from torch.utils.data import Dataset, DataLoader

import random
random.seed(42)  # For reproducibility
In [96]:
# Gather every tokenised song (.npy) produced by the preprocessing step.
tok_paths = []

for root, _subdirs, names in os.walk('data/tokens/train'):
    for name in names:
        if name.endswith('.npy'):
            tok_paths.append(os.path.join(root, name))
In [97]:
len(tok_paths)
Out[97]:
6150
In [98]:
# Reproducible 80/20 split: random.seed(42) above fixes the shuffle order,
# so any restructuring of this cell would change which songs land in train.
split_index = int(len(tok_paths) * 0.8)  # 80% train, 20% test
random.shuffle(tok_paths)

train_paths = tok_paths[:split_index]
test_paths = tok_paths[split_index:]
In [104]:
# ─── 1. Dataset + collate ────────────────────────────────────────────────
class MidiTokenDataset(Dataset):
    """Lazily loads one tokenised song (a 1-D int array) per .npy path."""

    def __init__(self, npy_paths):
        self.paths = npy_paths

    def __len__(self):
        """Number of songs in this split."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load song `idx` from disk as a 1-D int64 ndarray."""
        return np.load(self.paths[idx]).astype(np.int64)

def collate_fn(batch, pad_id, max_tokens=None):
    """Pad/crop a list of 1-D int sequences into one fixed-length batch.

    Parameters
    ----------
    batch : list of 1-D np.ndarray token sequences
    pad_id : int
        Token ID used to fill unused positions.
    max_tokens : int, optional
        Target length L. Defaults to the notebook-level MAX_TOKENS so
        existing two-argument callers keep working unchanged.

    Returns
    -------
    (x, pad_mask) : LongTensor (B, L), BoolTensor (B, L)
        pad_mask is True on real tokens, False on padding.
    """
    L = MAX_TOKENS if max_tokens is None else max_tokens
    B = len(batch)
    x = torch.full((B, L), pad_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        seq = torch.from_numpy(seq)
        if seq.numel() > L:
            # Random crop so long songs expose different windows per epoch.
            start = torch.randint(0, seq.numel() - L + 1, (1,)).item()
            seq = seq[start : start + L]
        x[i, : seq.numel()] = seq
    pad_mask = ~x.eq(pad_id)
    return x, pad_mask


# ─── 2. DataLoaders ──────────────────────────────────────────────────────
PAD_ID = tokenizer.vocab['<PAD>']          # padding token registered above

train_ds = MidiTokenDataset(train_paths)
val_ds   = MidiTokenDataset(test_paths)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
val_loader   = DataLoader(
    val_ds,   batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)

# ─── 3. Model, optimiser, scheduler ─────────────────────────────────────
# Prefer CUDA, then Apple-Silicon MPS, then CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = Synphony(
    vocab_size=len(tokenizer), d_model=D_MODEL,
    n_layers=N_LAYERS, n_heads=N_HEADS).to(device)

# AdamW decouples weight decay from the gradient update (regularisation).
optim = torch.optim.AdamW(model.parameters(),
                          lr=3e-4,
                          weight_decay=1e-2) # small wd to regularize

# Halve the LR whenever validation loss plateaus for 2 epochs.
# FIX: dropped `verbose=True` — the argument is deprecated and removed in
# recent PyTorch releases, and the loop already logs the LR each epoch.
# Also uses the ReduceLROnPlateau name imported at the top of the notebook.
sched = ReduceLROnPlateau(optim,
                          mode='min',        # val loss should go down
                          factor=0.5,        # cut LR in half
                          patience=2,        # wait 2 epochs
                          min_lr=1e-6)       # floor on LR


# ─── 4. Training loop ────────────────────────────────────────────────────
best_val_loss = float("inf")

for epoch in tqdm(range(1, 51)):                         # 50 epochs
    # ---- train ----------------------------------------------------------
    model.train()
    running_loss = 0.0

    for x, pad_mask in train_loader:          # x, pad_mask: (B, L)
        x, pad_mask = x.to(device), pad_mask.to(device)

        # Teacher forcing: predict token t+1 from tokens ≤ t.
        logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])

        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x[:, 1:].reshape(-1),
            ignore_index=PAD_ID,          # padding never contributes to loss
            label_smoothing=0.1
        )

        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        running_loss += loss.item()

    # NOTE: the train loss includes label smoothing, so this "perplexity"
    # is inflated relative to the un-smoothed validation perplexity below.
    train_ppl = math.exp(running_loss / len(train_loader))

    # ---- validation -----------------------------------------------------
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x, pad_mask in val_loader:             # pad_mask is (B, L)
            x, pad_mask = x.to(device), pad_mask.to(device)

            # exactly like in training, but without label smoothing
            logits  = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
            val_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                x[:, 1:].reshape(-1),
                ignore_index=PAD_ID
            ).item()

    val_ppl = math.exp(val_loss / len(val_loader))
    # FIX: the original printed the val PPL twice (a bare print followed by
    # the epoch summary below); once is enough.
    print(f"Epoch {epoch:02d} ▸ train PPL {train_ppl:6.2f} | val PPL {val_ppl:6.2f}")

    # ---- scheduler step -------------------------------------------------
    sched.step(val_loss / len(val_loader))  # average per-batch val loss

    # log current LR (ReduceLROnPlateau may have just halved it)
    current_lr = optim.param_groups[0]['lr']
    print(f"         ↳ LR now = {current_lr:.2e}")

    # ---- checkpoint: keep the best-on-validation weights ----------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "synphony_best.pt")
        print("  ✓ new best model saved")

print("Done!")
  0%|          | 0/50 [00:00<?, ?it/s]
val PPL  12.04
Epoch 01 ▸ train PPL  41.38 | val PPL  12.04
         ↳ LR now = 3.00e-04
  2%|▏         | 1/50 [08:14<6:43:27, 494.03s/it]
  ✓ new best model saved
val PPL   6.79
Epoch 02 ▸ train PPL  21.59 | val PPL   6.79
         ↳ LR now = 3.00e-04
  4%|▍         | 2/50 [16:31<6:36:43, 495.90s/it]
  ✓ new best model saved
val PPL   4.01
Epoch 03 ▸ train PPL  12.46 | val PPL   4.01
         ↳ LR now = 3.00e-04
  6%|▌         | 3/50 [24:48<6:28:57, 496.54s/it]
  ✓ new best model saved
val PPL   3.48
Epoch 04 ▸ train PPL   9.59 | val PPL   3.48
         ↳ LR now = 3.00e-04
  8%|▊         | 4/50 [33:05<6:20:55, 496.85s/it]
  ✓ new best model saved
val PPL   3.28
Epoch 05 ▸ train PPL   8.84 | val PPL   3.28
         ↳ LR now = 3.00e-04
 10%|█         | 5/50 [41:23<6:12:51, 497.16s/it]
  ✓ new best model saved
val PPL   3.17
Epoch 06 ▸ train PPL   8.44 | val PPL   3.17
         ↳ LR now = 3.00e-04
 12%|█▏        | 6/50 [49:40<6:04:39, 497.25s/it]
  ✓ new best model saved
val PPL   3.06
Epoch 07 ▸ train PPL   8.14 | val PPL   3.06
         ↳ LR now = 3.00e-04
 14%|█▍        | 7/50 [57:58<5:56:26, 497.36s/it]
  ✓ new best model saved
val PPL   3.03
Epoch 08 ▸ train PPL   7.97 | val PPL   3.03
         ↳ LR now = 3.00e-04
 16%|█▌        | 8/50 [1:06:15<5:48:09, 497.36s/it]
  ✓ new best model saved
val PPL   2.97
Epoch 09 ▸ train PPL   7.81 | val PPL   2.97
         ↳ LR now = 3.00e-04
 18%|█▊        | 9/50 [1:14:33<5:39:53, 497.39s/it]
  ✓ new best model saved
val PPL   2.92
Epoch 10 ▸ train PPL   7.64 | val PPL   2.92
         ↳ LR now = 3.00e-04
 20%|██        | 10/50 [1:22:50<5:31:34, 497.37s/it]
  ✓ new best model saved
val PPL   2.87
Epoch 11 ▸ train PPL   7.55 | val PPL   2.87
         ↳ LR now = 3.00e-04
 22%|██▏       | 11/50 [1:31:08<5:23:17, 497.38s/it]
  ✓ new best model saved
 24%|██▍       | 12/50 [1:39:24<5:14:51, 497.14s/it]
val PPL   2.87
Epoch 12 ▸ train PPL   7.42 | val PPL   2.87
         ↳ LR now = 3.00e-04
val PPL   2.82
Epoch 13 ▸ train PPL   7.35 | val PPL   2.82
         ↳ LR now = 3.00e-04
 26%|██▌       | 13/50 [1:47:42<5:06:43, 497.40s/it]
  ✓ new best model saved
val PPL   2.80
Epoch 14 ▸ train PPL   7.28 | val PPL   2.80
         ↳ LR now = 3.00e-04
 28%|██▊       | 14/50 [1:56:00<4:58:31, 497.54s/it]
  ✓ new best model saved
val PPL   2.77
Epoch 15 ▸ train PPL   7.23 | val PPL   2.77
         ↳ LR now = 3.00e-04
 30%|███       | 15/50 [2:04:18<4:50:18, 497.66s/it]
  ✓ new best model saved
val PPL   2.76
Epoch 16 ▸ train PPL   7.15 | val PPL   2.76
         ↳ LR now = 3.00e-04
 32%|███▏      | 16/50 [2:12:36<4:42:01, 497.70s/it]
  ✓ new best model saved
val PPL   2.74
Epoch 17 ▸ train PPL   7.10 | val PPL   2.74
         ↳ LR now = 3.00e-04
 34%|███▍      | 17/50 [2:20:54<4:33:44, 497.73s/it]
  ✓ new best model saved
val PPL   2.71
Epoch 18 ▸ train PPL   7.05 | val PPL   2.71
         ↳ LR now = 3.00e-04
 36%|███▌      | 18/50 [2:29:11<4:25:23, 497.62s/it]
  ✓ new best model saved
val PPL   2.67
Epoch 19 ▸ train PPL   7.00 | val PPL   2.67
         ↳ LR now = 3.00e-04
 38%|███▊      | 19/50 [2:37:29<4:17:08, 497.69s/it]
  ✓ new best model saved
 40%|████      | 20/50 [2:45:45<4:08:35, 497.17s/it]
val PPL   2.70
Epoch 20 ▸ train PPL   6.96 | val PPL   2.70
         ↳ LR now = 3.00e-04
 42%|████▏     | 21/50 [2:54:01<4:00:11, 496.95s/it]
val PPL   2.68
Epoch 21 ▸ train PPL   6.92 | val PPL   2.68
         ↳ LR now = 3.00e-04
val PPL   2.66
Epoch 22 ▸ train PPL   6.88 | val PPL   2.66
         ↳ LR now = 3.00e-04
 44%|████▍     | 22/50 [3:02:19<3:52:00, 497.15s/it]
  ✓ new best model saved
 46%|████▌     | 23/50 [3:10:35<3:43:35, 496.88s/it]
val PPL   2.66
Epoch 23 ▸ train PPL   6.82 | val PPL   2.66
         ↳ LR now = 3.00e-04
val PPL   2.63
Epoch 24 ▸ train PPL   6.80 | val PPL   2.63
         ↳ LR now = 3.00e-04
 48%|████▊     | 24/50 [3:18:53<3:35:24, 497.11s/it]
  ✓ new best model saved
 50%|█████     | 25/50 [3:27:09<3:26:59, 496.79s/it]
val PPL   2.65
Epoch 25 ▸ train PPL   6.79 | val PPL   2.65
         ↳ LR now = 3.00e-04
val PPL   2.62
Epoch 26 ▸ train PPL   6.75 | val PPL   2.62
         ↳ LR now = 3.00e-04
 52%|█████▏    | 26/50 [3:35:26<3:18:49, 497.05s/it]
  ✓ new best model saved
 54%|█████▍    | 27/50 [3:43:42<3:10:24, 496.72s/it]
val PPL   2.63
Epoch 27 ▸ train PPL   6.73 | val PPL   2.63
         ↳ LR now = 3.00e-04
val PPL   2.61
Epoch 28 ▸ train PPL   6.68 | val PPL   2.61
         ↳ LR now = 3.00e-04
 56%|█████▌    | 28/50 [3:52:00<3:02:12, 496.93s/it]
  ✓ new best model saved
 58%|█████▊    | 29/50 [4:00:16<2:53:51, 496.73s/it]
val PPL   2.61
Epoch 29 ▸ train PPL   6.67 | val PPL   2.61
         ↳ LR now = 3.00e-04
val PPL   2.60
Epoch 30 ▸ train PPL   6.64 | val PPL   2.60
         ↳ LR now = 3.00e-04
 60%|██████    | 30/50 [4:08:34<2:45:42, 497.11s/it]
  ✓ new best model saved
val PPL   2.57
Epoch 31 ▸ train PPL   6.65 | val PPL   2.57
         ↳ LR now = 3.00e-04
 62%|██████▏   | 31/50 [4:16:52<2:37:28, 497.31s/it]
  ✓ new best model saved
 64%|██████▍   | 32/50 [4:25:08<2:29:05, 496.96s/it]
val PPL   2.59
Epoch 32 ▸ train PPL   6.59 | val PPL   2.59
         ↳ LR now = 3.00e-04
 66%|██████▌   | 33/50 [4:33:24<2:20:44, 496.72s/it]
val PPL   2.58
Epoch 33 ▸ train PPL   6.57 | val PPL   2.58
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 34 ▸ train PPL   6.56 | val PPL   2.55
         ↳ LR now = 3.00e-04
 68%|██████▊   | 34/50 [4:41:42<2:12:32, 497.02s/it]
  ✓ new best model saved
 70%|███████   | 35/50 [4:49:58<2:04:11, 496.74s/it]
val PPL   2.57
Epoch 35 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04
 72%|███████▏  | 36/50 [4:58:14<1:55:51, 496.52s/it]
val PPL   2.57
Epoch 36 ▸ train PPL   6.53 | val PPL   2.57
         ↳ LR now = 3.00e-04
val PPL   2.55
Epoch 37 ▸ train PPL   6.53 | val PPL   2.55
         ↳ LR now = 3.00e-04
 74%|███████▍  | 37/50 [5:06:31<1:47:38, 496.79s/it]
  ✓ new best model saved
val PPL   2.54
Epoch 38 ▸ train PPL   6.49 | val PPL   2.54
         ↳ LR now = 3.00e-04
 76%|███████▌  | 38/50 [5:14:49<1:39:24, 497.00s/it]
  ✓ new best model saved
 78%|███████▊  | 39/50 [5:23:05<1:31:03, 496.68s/it]
val PPL   2.55
Epoch 39 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04
 80%|████████  | 40/50 [5:31:21<1:22:44, 496.46s/it]
val PPL   2.55
Epoch 40 ▸ train PPL   6.45 | val PPL   2.55
         ↳ LR now = 3.00e-04
val PPL   2.54
Epoch 41 ▸ train PPL   6.44 | val PPL   2.54
         ↳ LR now = 3.00e-04
 82%|████████▏ | 41/50 [5:39:38<1:14:30, 496.69s/it]
  ✓ new best model saved
val PPL   2.52
Epoch 42 ▸ train PPL   6.40 | val PPL   2.52
         ↳ LR now = 3.00e-04
 84%|████████▍ | 42/50 [5:47:55<1:06:15, 496.91s/it]
  ✓ new best model saved
 86%|████████▌ | 43/50 [5:56:11<57:56, 496.60s/it]  
val PPL   2.52
Epoch 43 ▸ train PPL   6.41 | val PPL   2.52
         ↳ LR now = 3.00e-04
 88%|████████▊ | 44/50 [6:04:27<49:38, 496.37s/it]
val PPL   2.53
Epoch 44 ▸ train PPL   6.38 | val PPL   2.53
         ↳ LR now = 3.00e-04
val PPL   2.52
Epoch 45 ▸ train PPL   6.38 | val PPL   2.52
         ↳ LR now = 1.50e-04
 90%|█████████ | 45/50 [6:12:45<41:23, 496.68s/it]
  ✓ new best model saved
val PPL   2.46
Epoch 46 ▸ train PPL   6.20 | val PPL   2.46
         ↳ LR now = 1.50e-04
 92%|█████████▏| 46/50 [6:21:02<33:07, 496.86s/it]
  ✓ new best model saved
val PPL   2.45
Epoch 47 ▸ train PPL   6.12 | val PPL   2.45
         ↳ LR now = 1.50e-04
 94%|█████████▍| 47/50 [6:29:19<24:50, 496.93s/it]
  ✓ new best model saved
 96%|█████████▌| 48/50 [6:37:35<16:33, 496.73s/it]
val PPL   2.46
Epoch 48 ▸ train PPL   6.09 | val PPL   2.46
         ↳ LR now = 1.50e-04
 98%|█████████▊| 49/50 [6:45:51<08:16, 496.51s/it]
val PPL   2.45
Epoch 49 ▸ train PPL   6.06 | val PPL   2.45
         ↳ LR now = 1.50e-04
val PPL   2.43
Epoch 50 ▸ train PPL   6.04 | val PPL   2.43
         ↳ LR now = 1.50e-04
100%|██████████| 50/50 [6:54:09<00:00, 496.99s/it]
  ✓ new best model saved
Done!

In [72]:
tokenizer.vocab_size
Out[72]:
3534

3. Model Inference¶

In [105]:
model.eval()
Out[105]:
Synphony(
  (embed): Embedding(3534, 768)
  (pos): RelativePositionalEncoding()
  (blocks): ModuleList(
    (0-7): 8 x TransformerDecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=768, out_features=3534, bias=True)
)
In [106]:
TEMPERATURE = 1.0
TOP_K = 8

# ─── 2. Helper for top-k filtering ───────────────────────────────────────
def top_k_logits(logits, k):
    """Keep only the `k` largest logits; set all others to -inf.

    Operates on a 1-D logits vector. `k` is clamped to the vocab size so a
    caller asking for more candidates than exist gets the logits back
    unchanged (the original raised inside `torch.topk` in that case).
    """
    k = min(k, logits.size(-1))
    v, _ = torch.topk(logits, k)
    threshold = v[-1]   # k-th largest value
    return torch.where(logits < threshold, torch.full_like(logits, -float("Inf")), logits)

# ─── 3. Autoregressive generation ────────────────────────────────────────
@torch.no_grad()
def generate(
        genre:str,
        artist:str,
        year:int,
        max_length:int = MAX_TOKENS
    ) -> list[int]:
    """Sample a token sequence conditioned on (genre, artist, year).

    Re-runs the full model on everything generated so far (no KV cache),
    applies temperature + top-k filtering to the last step's logits,
    samples one token, and stops early on <EOS> (which is not kept in
    the returned list). Returns the token IDs including the 3-token prefix.
    """
    prefix = build_prefix(genre, artist, year, tokenizer)
    input_ids = torch.tensor([prefix], device=device)               # (1, P)
    pad_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    for _ in tqdm(range(max_length - len(prefix))):
        logits = model(input_ids, pad_mask=pad_mask)
        step_logits = logits[0, -1, :] / TEMPERATURE                # (V,)
        step_logits = top_k_logits(step_logits, TOP_K)
        next_id = torch.multinomial(F.softmax(step_logits, dim=-1),
                                    num_samples=1)                  # (1,)
        if next_id.item() == tokenizer.vocab["<EOS>"]:
            break

        # Grow the context; every generated position is a real token.
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)
        pad_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    return input_ids[0].tolist()

# ─── 4. Decode to MIDI & save ────────────────────────────────────────────
def tokens_to_midi(token_ids: list[int], out_path: str):
    """
    Drop the 3 metadata tokens + optional EOS, then decode the rest
    back to MIDI and write it to `out_path`.
    """
    # 1) drop the first 3 prefix IDs (genre, artist, year)
    musical_ids = token_ids[3:]
    # 2) drop trailing <EOS> if present
    eos_id = tokenizer.vocab["<EOS>"]
    if len(musical_ids) > 0 and musical_ids[-1] == eos_id:
        musical_ids = musical_ids[:-1]

    # 3) decode only the musical tokens back to a symusic Score
    #    (not a PrettyMIDI object — that is why dump_midi() exists below)
    pm = tokenizer(musical_ids)
    # 4) write out the .mid file
    pm.dump_midi(out_path)
In [120]:
# ─── 5. Run it! ───────────────────────────────────────────────────────────
# Example user inputs. These must match slugs/years present in the training
# metadata; otherwise the <GENRE_...>/<ARTIST_...>/<YEAR_...> vocabulary
# lookup inside build_prefix raises KeyError.
genre_input  = "ROCK"
artist_input = "GLORIA_GAYNOR"
year_input   = 1990

gen_ids = generate(genre_input, artist_input, year_input, max_length=512)
out_file = "generated.mid"
tokens_to_midi(gen_ids, out_file)
print(f"🎹 Wrote MIDI to {out_file}")
100%|██████████| 509/509 [00:03<00:00, 131.18it/s]
🎹 Wrote MIDI to generated.mid

In [121]:
from midi2audio import FluidSynth
from IPython.display import Audio

# render your MIDI to a WAV
# NOTE(review): requires a SoundFont available to fluidsynth — the warning
# in the output above suggests the default SoundFont path was not found.
fs = FluidSynth()
fs.midi_to_audio('generated.mid', 'generated.wav')

# now embed the WAV inline (bare last expression → rich audio display)
Audio('generated.wav')
Parameter '/home/jupyter/.fluidsynth/default_sound_font.sf2' not a SoundFont or MIDI file or error occurred identifying it.
FluidSynth runtime version 2.1.7
Copyright (C) 2000-2021 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file 'generated.wav'..
Out[121]:
Your browser does not support the audio element.
In [ ]: